"""
PyTorch -> OneFlow weight-format conversion.
"""
import argparse
import pdb
from pathlib import Path

import oneflow as flow
import torch
from kamal import vision


parser = argparse.ArgumentParser(
    description="Convert a PyTorch ResNet-50 checkpoint into OneFlow format.")
parser.add_argument('--car_ckpt', required=True,
                    help="path to the PyTorch checkpoint (a pickled model or a state dict)")
# Kept so existing command lines keep working; the second (dog/SUN) conversion
# is currently disabled, so this value is not used.
parser.add_argument('--dog_ckpt', help="currently unused")
parser.add_argument('--num_classes', type=int, default=102,
                    help="classifier output size; must match the checkpoint (default: 102)")
parser.add_argument('--out', default="./ckpt/aircraft_res50_model",
                    help="destination passed to flow.save "
                         "(default: ./ckpt/aircraft_res50_model)")
args = parser.parse_args()

# Load onto the CPU so checkpoints saved from a GPU also convert on CPU-only hosts
# (the tensors are moved to the CPU below anyway).
# SECURITY: torch.load unpickles arbitrary objects - only convert trusted checkpoints.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses
# to unpickle a whole nn.Module; pass weights_only=False there if this fails.
checkpoint = torch.load(args.car_ckpt, map_location='cpu')
# Accept either a whole pickled model or a bare state dict.
cars_parameters = checkpoint.state_dict() if hasattr(checkpoint, 'state_dict') else checkpoint

# Convert every tensor to a NumPy array, dropping BatchNorm `num_batches_tracked`
# counters (presumably because the target model has no such buffer - confirm).
cars_para = {
    key: value.detach().cpu().numpy()
    for key, value in cars_parameters.items()
    if not str(key).endswith('num_batches_tracked')
}

# NOTE(review): assumes kamal's resnet50 builds a OneFlow module whose parameter
# names match the PyTorch model's - load_state_dict below will fail otherwise.
car_teacher = vision.models.classification.resnet50(num_classes=args.num_classes, pretrained=False)
car_teacher.load_state_dict(cars_para)

# Ensure the destination's parent directory exists before saving.
Path(args.out).parent.mkdir(parents=True, exist_ok=True)
flow.save(car_teacher.state_dict(), args.out)